import cv2 as cv
import numpy as np

import constants as cst


class MyDescriptorMatcher:
    """Match two sets of feature descriptors with OpenCV and keep the good ones.

    After :meth:`match`, a raw match is considered "good" when its distance is
    at most ``alpha * minDist``, where ``minDist`` is the smallest raw match
    distance clamped from below by ``cst.EPSILON_DIST``.
    """

    # Algorithm id of the LSH index in OpenCV's FLANN (KDTree would be 1).
    FLANN_INDEX_LSH = 6

    # Brute-force matcher types and the OpenCV norm each one uses.
    _BF_NORMS = {
        "BF_L2": cv.NORM_L2,
        "BF_L1": cv.NORM_L1,
        "BF_HAMMING": cv.NORM_HAMMING,
        "BF_HAMMING2": cv.NORM_HAMMING2,
    }

    # Human-readable label of every supported matcher type.
    _MATCHER_NAMES = {
        "BF_L2": "Brute Force Matcher (L2)",
        "BF_L1": "Brute Force Matcher (L1)",
        "BF_HAMMING": "Brute Force Matcher (Hamming)",
        "BF_HAMMING2": "Brute Force Matcher (Hamming2)",
        "FLANN": "FLANN Based Matcher",
    }

    def __init__(self, matcherType="BF_HAMMING", alpha=getattr(cst, "ALPHA_MATCH", 8.0),
                 crossCheck=True):
        """
        matcherType:
          - "BF_L2", "BF_L1", "BF_HAMMING", "BF_HAMMING2", "FLANN"
        alpha:
          - threshold multiplier for good matches (threshold = alpha * minDist);
            defaults to cst.ALPHA_MATCH when that constant exists, else 8.0
        crossCheck:
          - only for BF matchers (the FLANN matcher ignores it)

        Raises ValueError if matcherType is not one of the values above.
        """
        self.myDescriptorMatcherType = None
        self.myDescriptorMatcher = None

        self.alpha = float(alpha)
        self.crossCheck = bool(crossCheck)

        # Results of the last match() call, initialised so they can be read
        # safely even before match() has ever run.
        self.descriptors1 = None
        self.descriptors2 = None
        self.matches = []
        self.bestMatches = []
        self.minDist = float("inf")
        self.maxDist = 0.0

        self.changeDescriptorMatcher(matcherType)

    def changeDescriptorMatcher(self, matcherType):
        """Replace the current matcher by a new one of type ``matcherType``.

        The new matcher is built first, so an unknown type raises ValueError
        without touching the current matcher or ``myDescriptorMatcherType``.
        """
        newMatcher = self._createMatcher(matcherType)

        # Release the previous matcher if any. BFMatcher supports .clear(), but
        # .release() is not available everywhere. This is best effort only.
        if self.myDescriptorMatcher is not None:
            try:
                self.myDescriptorMatcher.clear()
            except (AttributeError, cv.error):
                pass

        self.myDescriptorMatcherType = matcherType
        self.myDescriptorMatcher = newMatcher

    def _createMatcher(self, matcherType):
        """Build and return an OpenCV matcher for ``matcherType``; raise ValueError if unknown."""
        # --- Brute Force variants ---
        if matcherType in self._BF_NORMS:
            return cv.BFMatcher(self._BF_NORMS[matcherType], crossCheck=self.crossCheck)

        # --- FLANN ---
        if matcherType == "FLANN":
            # FLANN: be careful with the index type.
            # - Float descriptors (SIFT/SURF): KDTree (algorithm=1).
            # - Binary descriptors (ORB/AKAZE/BRISK): LSH (algorithm=6).
            # A binary (LSH) FLANN is built here by default because the lab uses ORB/AKAZE.
            index_params = dict(algorithm=self.FLANN_INDEX_LSH, table_number=6,
                                key_size=12, multi_probe_level=1)
            search_params = dict(checks=50)
            return cv.FlannBasedMatcher(index_params, search_params)

        raise ValueError(f"Unknown matcherType: {matcherType}")

    # -------------------------------------------------------------------------
    # Perform matching between descriptors and returns best matches
    # -------------------------------------------------------------------------
    def match(self, descriptors1, descriptors2):
        """Match descriptors1 (query) against descriptors2 (train) and return the good matches.

        Side effects: the inputs are stored in ``self.descriptors1`` and
        ``self.descriptors2``. All raw matches, sorted by increasing distance,
        go in ``self.matches``, and the filtered result goes in
        ``self.bestMatches``. The distance extremes go in ``self.minDist`` and
        ``self.maxDist``; ``minDist`` holds the clamped value used for the
        threshold.

        A match is kept when ``distance <= alpha * minDist``.

        Returns an empty list in three cases: either descriptor set is
        None/empty; the matcher raises (the error is printed); or no raw match
        is produced.
        """
        self.descriptors1 = descriptors1
        self.descriptors2 = descriptors2

        self.matches = []
        self.bestMatches = []

        self.minDist = float("inf")
        self.maxDist = 0.0

        # Robust guards: nothing to match.
        if descriptors1 is None or descriptors2 is None:
            return self.bestMatches
        if len(descriptors1) == 0 or len(descriptors2) == 0:
            return self.bestMatches

        try:
            # FLANN + binary descriptors: OpenCV usually expects uint8, but some
            # FLANN configs need float32. With LSH, uint8 is fine.
            rawMatches = self.myDescriptorMatcher.match(descriptors1, descriptors2)
        except Exception as e:
            # Best effort: report the matcher failure and return no match.
            print(e)
            print("Matching failed")
            return self.bestMatches

        # Sort for a stable ordering (best match first).
        self.matches = sorted(rawMatches, key=lambda m: m.distance)

        print("Number of matches:", len(self.matches))

        if not self.matches:
            return self.bestMatches

        # The list is sorted, so the extremes are its first and last entries.
        self.minDist = float(self.matches[0].distance)
        self.maxDist = float(self.matches[-1].distance)

        # Avoid a tiny minDist (e.g. 0 for exact Hamming matches), which would
        # shrink the threshold to (almost) nothing.
        self.minDist = max(self.minDist, float(cst.EPSILON_DIST))

        # Keep only good matches.
        threshold = self.alpha * self.minDist
        self.bestMatches = [m for m in self.matches if m.distance <= threshold]

        return self.bestMatches

    # -------------------------------------------------------------------------
    # Draw best matches into an image and returns it
    # -------------------------------------------------------------------------
    def drawMatchingResults(self, image1, image2, bestMatches, features1, features2):
        """Draw matches between image1 and image2 placed side by side; return the new image.

        For each match, a circle of colour (255, 0, 0) in BGR is drawn on both
        keypoints, and the two keypoints are joined by a line. Grayscale
        inputs are converted to BGR. The shorter image is padded with black
        rows at the bottom, so images of different heights can be combined.
        The input images are not modified.

        bestMatches: match objects whose ``queryIdx`` indexes features1 and
            whose ``trainIdx`` indexes features2.
        features1, features2: keypoints exposing an ``(x, y)`` ``.pt``.
        """
        panels = []
        for img in (image1, image2):
            # Convert grayscale to BGR so the matches can be drawn in colour.
            if img.ndim == 2:
                img = cv.cvtColor(img, cv.COLOR_GRAY2BGR)
            panels.append(img)

        # np.concatenate(axis=1) needs equal heights: pad the shorter image at the bottom.
        height = max(p.shape[0] for p in panels)
        panels = [
            np.pad(p, [(0, height - p.shape[0])] + [(0, 0)] * (p.ndim - 1))
            for p in panels
        ]
        image = np.concatenate(panels, axis=1)

        # Horizontal offset of the second image inside the combined canvas.
        offset = image1.shape[1]
        color = (255, 0, 0)

        for m in bestMatches:
            x1, y1 = features1[m.queryIdx].pt
            x2, y2 = features2[m.trainIdx].pt
            p1 = (int(x1), int(y1))
            p2 = (int(x2) + offset, int(y2))

            cv.circle(image, p1, 4, color, 1)
            cv.circle(image, p2, 4, color, 1)
            cv.line(image, p1, p2, color, 1)

        return image

    # -------------------------------------------------------------------------
    # Getting the name of the matching method
    # -------------------------------------------------------------------------
    def getMatcherName(self, matcherType=None):
        """Return a readable name for ``matcherType``.

        When ``matcherType`` is omitted, the current matcher type is used.
        Unrecognised types yield "Unknown Matcher".
        """
        if matcherType is None:
            matcherType = self.myDescriptorMatcherType
        return self._MATCHER_NAMES.get(matcherType, "Unknown Matcher")
